/*
 *    OzaBoost.java
 *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
package moa.classifiers;

import sizeof.agent.SizeOfAgent;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.options.ClassOption;
import moa.options.FlagOption;
import moa.options.IntOption;
import weka.core.Instance;

public class OzaBoost extends AbstractClassifier {

	private static final long serialVersionUID = 1L;

	public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
			"Classifier to train.", Classifier.class, "HoeffdingTree");

	public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
			"The number of models to boost.", 10, 1, Integer.MAX_VALUE);

	public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p',
			"Boost with weights only; no poisson.");

	protected Classifier[] ensemble;

	protected double[] scms;

	protected double[] swms;

	@Override
	public int measureByteSize() {
		int size = (int) SizeOfAgent.sizeOf(this);
		for (Classifier classifier : this.ensemble) {
			size += classifier.measureByteSize();
		}
		return size;
	}

	@Override
	public void resetLearningImpl() {
		this.ensemble = new Classifier[this.ensembleSizeOption.getValue()];
		Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
		baseLearner.resetLearning();
		for (int i = 0; i < this.ensemble.length; i++) {
			this.ensemble[i] = baseLearner.copy();
		}
		this.scms = new double[this.ensemble.length];
		this.swms = new double[this.ensemble.length];
	}

	@Override
	public void trainOnInstanceImpl(Instance inst) {
		double lambda_d = 1.0;
		for (int i = 0; i < this.ensemble.length; i++) {
			double k = this.pureBoostOption.isSet() ? lambda_d : MiscUtils
					.poisson(lambda_d, this.classifierRandom);
			if (k > 0.0) {
				Instance weightedInst = (Instance) inst.copy();
				weightedInst.setWeight(inst.weight() * k);
				this.ensemble[i].trainOnInstance(weightedInst);
			}
			if (this.ensemble[i].correctlyClassifies(inst)) {
				this.scms[i] += lambda_d;
				lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]);
			} else {
				this.swms[i] += lambda_d;
				lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]);
			}
		}
	}

	protected double getEnsembleMemberWeight(int i) {
		double em = this.swms[i] / (this.scms[i] + this.swms[i]);
		if ((em == 0.0) || (em > 0.5)) {
			return 0.0;
		}
		double Bm = em / (1.0 - em);
		return Math.log(1.0 / Bm);
	}

	public double[] getVotesForInstance(Instance inst) {
		DoubleVector combinedVote = new DoubleVector();
		for (int i = 0; i < this.ensemble.length; i++) {
			double memberWeight = getEnsembleMemberWeight(i);
			if (memberWeight > 0.0) {
				DoubleVector vote = new DoubleVector(this.ensemble[i]
						.getVotesForInstance(inst));
				if (vote.sumOfValues() > 0.0) {
					vote.normalize();
					vote.scaleValues(memberWeight);
					combinedVote.addValues(vote);
				}
			} else {
				break;
			}
		}
		return combinedVote.getArrayRef();
	}

	public boolean isRandomizable() {
		return true;
	}

	@Override
	public void getModelDescription(StringBuilder out, int indent) {
		// TODO Auto-generated method stub

	}

	@Override
	protected Measurement[] getModelMeasurementsImpl() {
		return new Measurement[] { new Measurement("ensemble size",
				this.ensemble != null ? this.ensemble.length : 0) };
	}

	@Override
	public Classifier[] getSubClassifiers() {
		return this.ensemble.clone();
	}

}
